
import dionysus as dion
import numpy as np
from scipy.optimize import linear_sum_assignment as hungarian
import math
import random
import warnings


def simplices_to_diagram(simplices):
    """Build persistence diagrams from a list of ``(vertices, time)`` pairs.

    Every pair is wrapped in a ``dion.Simplex`` and appended to a
    filtration, the filtration is sorted, and the result of
    ``dion.homology_persistence`` is turned into diagrams.

    Args:
        simplices: Sequence of ``(vertices, time)`` pairs.

    Returns:
        The value returned by ``dion.init_diagrams`` for this filtration.
    """
    # NOTE(review): the filtration is constructed from ``simplices`` and the
    # same pairs are appended again in the loop below -- confirm against the
    # dionysus API that this does not insert every simplex twice.
    filtration = dion.Filtration(simplices)
    for simplex_vertices, filtration_value in simplices:
        filtration.append(dion.Simplex(simplex_vertices, filtration_value))
    filtration.sort()

    reduced = dion.homology_persistence(filtration)
    return dion.init_diagrams(reduced, filtration)


def calc_W(D, M):
    """Tabulate the distance from every diagram to every cluster centre.

    Args:
        D: Sequence of diagrams.
        M: Sequence of cluster centres.

    Returns:
        Float array of shape ``(len(D), len(M))`` whose ``[j, k]`` entry is
        ``calc_wasserstein(D[j], M[k])``.  A distance of exactly zero is
        stored as ``0.001`` instead.
    """
    distances = np.zeros((len(D), len(M)))

    for row, diagram in enumerate(D):
        for col, centre in enumerate(M):
            value = calc_wasserstein(diagram, centre)
            # exact zeros are replaced by 0.001 so the table never holds a 0
            distances[row, col] = value if value != 0 else 0.001
    return distances


def calc_r(W):
    """Compute fuzzy membership values from a distance matrix.

    For diagram ``j`` and centre ``k`` the membership is::

        r[j, k] = 1 / sum_l (W[j, k] / W[j, l])

    so each row of the result sums to 1 (up to rounding).

    Args:
        W: ``(n, c)`` array-like of distances.  Entries are used as
            divisors, so they are expected to be non-zero.

    Returns:
        ``(n, c)`` float array of membership values.
    """
    W = np.asarray(W, dtype=float)
    # ratios[j, k, l] = W[j, k] / W[j, l]; broadcasting replaces the three
    # nested Python loops.
    ratios = W[:, :, np.newaxis] / W[:, np.newaxis, :]
    return 1.0 / ratios.sum(axis=2)


def add_diagonals(D):
    """Pad every diagram with diagonal placeholders to a common length.

    A diagonal point is encoded as ``[-1, -1]``.  Each diagram is padded
    until it has as many points as the largest diagram in ``D``.

    Args:
        D: List of diagrams, each a list of ``[birth, death]`` points.

    Returns:
        The same list ``D``; its diagrams are extended in place.
    """
    # common size = the largest number of points found in any diagram
    target = max((len(diagram) for diagram in D), default=0)

    for diagram in D:
        missing = target - len(diagram)
        # a fresh [-1, -1] list per placeholder, so entries are never aliased
        diagram.extend([-1, -1] for _ in range(missing))

    return D


def calc_cost_matrix(Dj, Mk):
    """Build the cost matrix of the optimal transport problem for two diagrams.

    Points are ``[birth, death]`` pairs and a birth of ``-1`` marks a
    diagonal placeholder.  Entry ``[i, j]`` of the result is

    * the squared Euclidean distance between ``Dj[i]`` and ``Mk[j]`` when
      both are off-diagonal,
    * ``((death - birth) / sqrt(2)) ** 2`` of the off-diagonal point when
      exactly one of the two is a placeholder,
    * ``0`` when both are placeholders.

    Args:
        Dj: Sequence (or ``(m, 2)`` array) of points of the first diagram.
        Mk: Sequence (or ``(m, 2)`` array) of points of the second diagram;
            must have the same length as ``Dj``.

    Returns:
        ``(m, m)`` float array of costs.

    Raises:
        SystemExit: If the two diagrams have different lengths.
    """
    m = len(Dj)
    if m != len(Mk):
        # Raise SystemExit directly instead of calling ``exit()``: that helper
        # is injected by the ``site`` module for interactive use and is not
        # guaranteed to exist.  Exception type and message are unchanged.
        raise SystemExit("Incompatible diagram size in calc_cost_matrix: " + str(len(Dj)) + " and " + str(len(Mk)))

    if m == 0:
        # nothing to index below; keep the empty (0, 0) result
        return np.zeros((0, 0))

    rows = np.asarray(Dj, dtype=float)
    cols = np.asarray(Mk, dtype=float)
    row_birth, row_death = rows[:, 0], rows[:, 1]
    col_birth, col_death = cols[:, 0], cols[:, 1]

    # True where the point is a real (off-diagonal) feature
    row_off = (row_birth != -1)[:, np.newaxis]
    col_off = (col_birth != -1)[np.newaxis, :]

    # squared distance between every pair of points
    between = ((row_birth[:, np.newaxis] - col_birth[np.newaxis, :]) ** 2
               + (row_death[:, np.newaxis] - col_death[np.newaxis, :]) ** 2)
    # squared distance from each point to the diagonal
    row_to_diag = ((row_death - row_birth) * 1 / math.sqrt(2)) ** 2
    col_to_diag = ((col_death - col_birth) * 1 / math.sqrt(2)) ** 2

    cost = np.zeros((m, m))
    cost = np.where(row_off & col_off, between, cost)
    cost = np.where(row_off & ~col_off, row_to_diag[:, np.newaxis], cost)
    cost = np.where(~row_off & col_off, col_to_diag[np.newaxis, :], cost)
    return cost


def calc_wasserstein(Dj, Mk):
    """Return the 2-Wasserstein (L2 ground metric) distance of two diagrams.

    The cost matrix from ``calc_cost_matrix`` is solved as a linear
    assignment problem; the distance is the square root of the summed cost
    of the assigned pairs.

    Args:
        Dj: First diagram.
        Mk: Second diagram, same length as ``Dj``.

    Returns:
        The distance as a float.
    """
    cost = calc_cost_matrix(Dj, Mk)
    row_idx, col_idx = hungarian(cost)

    # accumulate the cost of each assigned (row, column) pair in order
    matched_cost = 0
    for row, col in zip(row_idx, col_idx):
        matched_cost += cost[row][col]
    return math.sqrt(matched_cost)


def calc_frechet_mean(D, r, k, verbose):
    """Iteratively compute the weighted Frechet mean of the diagrams in ``D``.

    Diagram ``D[j]`` is weighted by ``r[j][k] ** 2``.  Starting from one of
    the input diagrams, the routine alternates between (1) matching the
    current mean to every diagram with the Hungarian algorithm and
    (2) moving every mean point to the weighted average of the points it is
    matched to, until the matchings no longer change.

    Args:
        D: Sequence of diagrams, each a sequence of ``[birth, death]``
            points with ``-1`` as birth marking a diagonal placeholder.
            All diagrams are treated as having ``len(D[0])`` points.
        r: Membership table indexed as ``r[j][k]``.
        k: Column of ``r`` to use; also printed when ``verbose`` is set.
        verbose: If true, print the number of iterations on convergence.

    Returns:
        Tuple ``(y, x)``: ``y`` is an ``(m, 2)`` array with the mean diagram
        and ``x`` is an ``(n, m, 2)`` array where ``x[j][i]`` is the point of
        ``D[j]`` matched to mean point ``i``.
    """
    n = len(D)
    m = len(D[0])
    # initialise to random diagram in D
    # NOTE(review): this reseeds the module-level ``random`` generator with 0
    # on every call, so the same start diagram is always chosen and the
    # caller's global random state is overwritten.
    random.seed(0)
    M_update = D[random.randint(0, n-1)]

    # first run to find matching
    # matching[j] is the (row indices, column indices) pair returned by
    # linear_sum_assignment; rows index mean points, columns index D[j].
    matching = []
    for j in range(n):
        c = calc_cost_matrix(M_update, D[j])
        x_indices = hungarian(c)
        matching.append(x_indices)

    # loop until stopping condition is found
    # NOTE(review): there is no iteration cap -- the loop only ends once two
    # consecutive rounds produce identical matchings.
    counter2 = 0

    while True:
        counter2 += 1

        # update matched points
        # x[j][i] = the point of D[j] assigned to mean point i
        # (presumably relies on the row indices being 0..m-1 in order, which
        # scipy documents for square cost matrices -- confirm)
        x = np.zeros((n, m, 2))
        for j in range(n):
            for i in range(m):
                index = matching[j][1][i]
                x[j][i] = D[j][index]

        # generate y to return
        y = np.zeros((m, 2))

        # loop over each point
        for i in range(m):
            # calculate w and w_\Delta
            # r2_od: total squared weight of the off-diagonal matches;
            # r2x_od: squared-weight-scaled sum of those points
            r2_od = 0
            r2x_od = [0, 0]
            for j in range(n):
                if x[j][i][0] != -1:
                    r2_od += r[j][k]**2
                    r2x_od[0] += r[j][k]**2 * x[j][i][0]
                    r2x_od[1] += r[j][k]**2 * x[j][i][1]

            # if all points are diagonals
            # (also taken when every off-diagonal match has zero weight)
            if r2_od == 0:
                # then y[i] is a diagonal
                y[i] = [-1, -1]

            # else proceed
            else:
                # w: weighted mean of the off-diagonal matches;
                # w_delta: the point on the diagonal with both coordinates
                # equal to the average of w's coordinates
                w = [r2x_od[0]/r2_od, r2x_od[1]/r2_od]
                w_delta = [(w[0]+w[1])/2, (w[0]+w[1])/2]

                # matches that are diagonal placeholders contribute w_delta
                # with their own squared weight
                r2_d = 0
                r2_w_delta = [0, 0]
                for j in range(n):
                    if x[j][i][0] == -1:
                        r2_d += r[j][k] ** 2
                        r2_w_delta[0] += r[j][k]**2 * w_delta[0]
                        r2_w_delta[1] += r[j][k]**2 * w_delta[1]

                # calculate weighted mean
                y[i][0] = (r2x_od[0] + r2_w_delta[0]) / (r2_od + r2_d)
                y[i][1] = (r2x_od[1] + r2_w_delta[1]) / (r2_od + r2_d)

        # re-match the updated mean y against every diagram
        old_matching = matching.copy()
        matching = []
        for j in range(n):
            c = calc_cost_matrix(y, D[j])
            x_indices = hungarian(c)
            matching.append(x_indices)

        # converged when every assignment index is unchanged from last round
        comparison = (np.array(matching) == np.array(old_matching))
        if comparison.all():
            if verbose:
                print("      Frechet iterations for M_" + str(k) + ": " + str(counter2))
            return y, x


def init_clusters(D, c):
    """Create ``c`` initial cluster centres from consecutive diagram pairs.

    Centre ``i`` is the Frechet mean of ``D[i]`` and ``D[i + 1]`` computed
    with all weights equal to one, so ``D`` has to contain at least
    ``c + 1`` diagrams.

    Args:
        D: Sequence of (padded) diagrams.
        c: Number of cluster centres to create.

    Returns:
        List with ``c`` centre diagrams.
    """
    # all-ones weight table handed to calc_frechet_mean for every pair
    uniform = np.ones((len(D) + 1, c + 1))
    return [
        calc_frechet_mean([D[i], D[i + 1]], uniform, i, verbose=False)[0]
        for i in range(c)
    ]


def labels_to_diagrams(n_graphs, nodes, edges, node_labels, edge_labels):
    """Compute one persistence diagram per labelled graph.

    For graph ``j`` every entry of ``nodes[j]`` and then every entry of
    ``edges[j]`` is appended to a fresh filtration as a ``dion.Simplex``
    whose value is the matching label's ``.item()``.

    Args:
        n_graphs: Number of graphs to process.
        nodes: Per-graph sequences of node simplices.
        edges: Per-graph sequences of edge simplices.
        node_labels: Per-graph node labels; each label must offer ``.item()``.
        edge_labels: Per-graph edge labels; each label must offer ``.item()``.

    Returns:
        List with element ``[0]`` of ``dion.init_diagrams`` for every graph.
    """
    result = []
    for graph in range(n_graphs):
        graph_nodes, graph_edges = nodes[graph], edges[graph]
        filtration = dion.Filtration()

        for idx in range(len(graph_nodes)):
            value = node_labels[graph][idx].item()
            filtration.append(dion.Simplex(graph_nodes[idx], value))
        for idx in range(len(graph_edges)):
            value = edge_labels[graph][idx].item()
            filtration.append(dion.Simplex(graph_edges[idx], value))

        # NOTE(review): unlike simplices_to_diagram, the filtration is not
        # sorted before persistence is computed -- confirm this is intended.
        reduced = dion.homology_persistence(filtration)
        result.append(dion.init_diagrams(reduced, filtration)[0])

    return result


def J(r, D, M):
    """Evaluate the clustering objective for memberships ``r``.

    The value is the sum over all diagrams ``j`` and centres ``k`` of
    ``r[j][k] ** 2 * W[j][k] ** 2`` with ``W = calc_W(D, M)``.

    Args:
        r: Membership table indexed as ``r[j][k]``.
        D: Sequence of diagrams.
        M: Sequence of cluster centres.

    Returns:
        The accumulated objective value.
    """
    distances = calc_W(D, M)
    n_diagrams, n_centres = np.shape(distances)

    objective = 0
    for j in range(n_diagrams):
        for k in range(n_centres):
            objective += r[j][k] ** 2 * distances[j][k] ** 2

    return objective


def reformat_diagrams(D, T=100):
    """Convert diagrams of point objects into nested ``[birth, death]`` lists.

    Args:
        D: Sequence of diagrams whose points expose ``birth`` and ``death``
            attributes.
        T: Value substituted for a death that ``np.isinf`` reports as
            infinite.

    Returns:
        List of diagrams, each a list of ``[birth, death]`` pairs.
    """
    return [
        [[point.birth, T if np.isinf(point.death) else point.death]
         for point in diagram]
        for diagram in D
    ]


def calc_dist(points):
    """List the squared Euclidean distance of every unordered point pair.

    Only the first two coordinates of each point are used.  Pairs are
    visited as ``(0, 1), (0, 2), ..., (1, 2), ...``.

    Args:
        points: Indexable collection of points.

    Returns:
        List of squared distances, one per pair, in the order above.
    """
    count = len(points)
    return [
        (points[a][0] - points[b][0]) ** 2 + (points[a][1] - points[b][1]) ** 2
        for a in range(count)
        for b in range(a + 1, count)
    ]


def perturb(x):
    """Nudge repeated values in ``x`` until every value is distinct.

    In each pass, for every pair ``i < j`` with equal values, ``x[j]`` is
    increased by ``0.0001`` times a counter that grows with every nudge made
    in that pass.  Passes repeat while duplicates remain; from the third
    pass onwards a warning is issued, but the loop is not stopped.

    Args:
        x: Mutable sequence of numbers; it is modified in place.

    Returns:
        The same object ``x``.
    """
    passes = 0
    size = len(x)
    while len(set(x)) != size:
        step = 1
        for first in range(size):
            for second in range(first + 1, size):
                if x[second] == x[first]:
                    x[second] += 0.0001 * step
                    step += 1
        passes += 1

        # only a warning here; nothing terminates the loop
        if passes > 2:
            warnings.warn("Warning, perturbation iteration: " + str(passes))

    return x


def perturb_points(points):
    """Perturb 2-D points until all pairwise squared distances are unique.

    While ``calc_dist`` reports a repeated distance, the x and y coordinate
    columns are each passed through ``perturb`` and the points are rebuilt
    from the two columns.

    Args:
        points: Array of points; it must support ``.T`` (e.g. a numpy
            array) with the x coordinates in column 0 and y in column 1.

    Returns:
        The points once all pairwise distances differ.  This is the input
        object itself when no perturbation was needed, otherwise a new
        two-column array.

    Raises:
        SystemExit: After ten perturbation rounds without success.
    """
    dist = calc_dist(points)

    perturb_counter = 0
    while len(dist) != len(set(dist)):
        # NOTE(review): ``perturb`` edits its argument in place; if
        # ``points.T[0]`` is a view (numpy array), the caller's array is
        # modified on the first round -- confirm callers expect this.
        # NOTE(review): ``perturb`` only separates duplicate coordinates, so
        # a distance tie between points whose coordinates are already all
        # distinct is never resolved and ends in the SystemExit below.
        x = perturb(points.T[0])
        y = perturb(points.T[1])

        points = np.array([x, y]).T

        dist = calc_dist(points)
        perturb_counter += 1

        # guard against an endless loop: warn first, then give up
        if perturb_counter > 2:
            warnings.warn("Warning, perturbation iteration: " + str(perturb_counter))
            if perturb_counter == 10:
                # Raise SystemExit directly instead of calling ``exit()``:
                # that helper comes from the ``site`` module and is meant for
                # interactive sessions.  Exception and message are unchanged.
                raise SystemExit("Perturbations OUT OF CONTROL")

    return points


def pd_fuzzy(D, c, verbose=False, max_iter=5):
    """Compute fuzzy clusters of persistence diagrams.

    The diagrams are padded with diagonal placeholders, ``c`` centres are
    initialised, and then ``max_iter`` rounds alternate between updating
    the membership values and updating the cluster centres.

    Args:
        D: List of persistence diagrams (lists of ``[birth, death]``
            points).  The diagrams are padded in place by ``add_diagonals``.
        c: Number of clusters.
        verbose: If true, print iteration information and the objective
            ``J`` after each half-step.
        max_iter: Number of iterations to run; must be positive.

    Returns:
        Tuple ``(r, M)``: ``r`` holds the membership values and ``M`` is the
        list of cluster centres.  ``r`` is the table computed at the start
        of the last iteration, i.e. before the final centre update.

    Raises:
        SystemExit: If ``max_iter`` is zero or negative.
    """
    if max_iter <= 0:
        # A negative value used to skip the loop entirely and fail later on
        # the unbound ``r``; reject it here together with zero.  SystemExit
        # is raised directly rather than through the site-only ``exit()``.
        if max_iter == 0:
            raise SystemExit('Maximum iterations is zero')
        raise SystemExit('Maximum iterations is negative')

    D = add_diagonals(D)
    M = init_clusters(D, c)

    counter = 0
    while counter < max_iter:
        if verbose:
            print("Fuzzy iteration: " + str(counter))
        counter += 1

        # update membership values
        W = calc_W(D, M)
        r = calc_r(W)

        if verbose:
            J_temp = J(r, D, M)
            print(" -- Update r -- ")
            print("   J(r, M) = " + str(J_temp))

            print(" -- Update M -- ")

        # update cluster centres; the matched points that calc_frechet_mean
        # also returns are not needed here
        for k in range(c):
            M[k], _ = calc_frechet_mean(D, r, k, verbose)

        if verbose:
            J_new = J(r, D, M)
            print("   J(r, M) = " + str(J_new) + "\n")

    return r, M
